{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SPA against RSA on XMEGA (8-bit implementation)\n",
    "\n",
    "Supported setups:\n",
    "\n",
    "SCOPES:\n",
    "\n",
    "* OPENADC\n",
    "\n",
    "PLATFORMS:\n",
    "\n",
    "* CWLITEXMEGA or CW303\n",
    "\n",
    "Note that this tutorial *only* works with an XMEGA target. The RSA implementation in use is `avr-crypto-lib`, which includes AVR assembly code to accelerate certain routines, so it does not work on any other platform. A later tutorial will demonstrate a similar (but not *exactly* the same) attack on the MBED-TLS RSA implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SCOPETYPE = 'OPENADC'\n",
    "PLATFORM = 'CWLITEXMEGA'\n",
    "CRYPTO_TARGET = 'AVRCRYPTOLIB'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Firmware"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like before, we'll need to set up our `PLATFORM`, then build the firmware if you prefer to \"cook your own\". If you don't have the avr-gcc compiler installed, you can use the provided .hex file instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash -s \"$PLATFORM\" \"$CRYPTO_TARGET\"\n",
    "cd ../../../firmware/mcu/simpleserial-rsa\n",
    "make PLATFORM=$1 CRYPTO_TARGET=$2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setup is the same as usual. Note you don't need to rebuild the RSA implementation if you don't have the AVR compiler available. Check the firmware directory to see the filename for the XMEGA target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run \"../../Setup_Scripts/Setup_Generic.ipynb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fw_path = '../../../firmware/mcu/simpleserial-rsa/simpleserial-rsa-CWLITEXMEGA.hex'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cw.program_target(scope, prog, fw_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Communicating With Target and Testing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The target is different than normal. We don't use the key commands; instead, whatever is sent as the 'text' becomes the key in use. Due to the limited capture length, we need to capture a smaller than normal RSA key. While we might normally have a 512/1024/2048 bit key, we are going to capture only part of it.\n",
    "\n",
    "Doing so means sending a message with leading 0's, which are not processed. By using only the lower 16 bits of our message, we'll effectively process a 16-bit key."
   ]
  },
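  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the zero-padding described above, the sketch below packs a small exponent into a 16-byte buffer. `make_key_buffer` is a hypothetical helper, not part of the ChipWhisperer API, and it assumes the big-endian layout used for the key buffers in this notebook.\n",
    "\n",
    "```python\n",
    "def make_key_buffer(exponent, total_len=16):\n",
    "    # Hypothetical helper: pack a small exponent into a zero-padded,\n",
    "    # big-endian buffer so that only the low-order bytes are non-zero.\n",
    "    if exponent < 0 or exponent >= 1 << (8 * total_len):\n",
    "        raise ValueError('exponent does not fit in the buffer')\n",
    "    return bytearray(exponent.to_bytes(total_len, 'big'))\n",
    "\n",
    "buf = make_key_buffer(0x8AB0)\n",
    "print(buf.hex())\n",
    "```\n",
    "\n",
    "Only the last two bytes are non-zero here, which matches the 16-bit keys sent to the target in this tutorial."
   ]
  },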
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scope.clock.adc_src = \"clkgen_x1\"\n",
    "scope.adc.samples = 24000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To compare different keys, we can perform two captures. You should see an obvious difference between the two power traces below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib notebook\n",
    "import matplotlib.pylab as plt\n",
    "\n",
    "def capture_RSA_trace(scope, target, text):\n",
    "    # Arm the scope, send the key bytes as the 'text', and capture one trace\n",
    "    scope.arm()\n",
    "    target.simpleserial_write('p', text)\n",
    "    ret = scope.capture()\n",
    "    if ret:\n",
    "        # A truthy return value means the capture failed (e.g. timed out)\n",
    "        return None\n",
    "    target.simpleserial_wait_ack()\n",
    "    return scope.get_last_trace()\n",
    "\n",
    "# Two keys that differ in a single bit (0x8000 vs 0x8100)\n",
    "text = bytearray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x80, 0x00])\n",
    "trace = capture_RSA_trace(scope, target, text)\n",
    "plt.plot(trace, 'r')\n",
    "text = bytearray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x81, 0x00])\n",
    "trace = capture_RSA_trace(scope, target, text)\n",
    "plt.plot(trace, 'b')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Breaking RSA\n",
    "Now that we have a target we can capture power traces from, how do we break RSA? The easiest way is actually with a \"single-trace\" attack. Let's capture a single RSA trace here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = bytearray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x8A, 0xB0])\n",
    "trace = capture_RSA_trace(scope, target, key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can refer to the original RSA code, but our objective is simply to compare iterations through the loop. An easy way to do this is to take some \"reference\" part of the waveform and sweep it through the trace. Basically, we need to compare two sections of the waveform and see how closely they match at each point in time. Take a look at this original power trace:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib notebook\n",
    "import matplotlib.pylab as plt\n",
    "plt.plot(trace, 'r')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will select some small section as a \"reference\". For example, take a look at the points plotted here. Does this section look unique?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib notebook\n",
    "import matplotlib.pylab as plt\n",
    "ref_trace = trace[3600:4100]\n",
    "plt.plot(ref_trace, 'b')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you slide the \"offset\" along, you'll see the reference pattern slide along the trace as well. The sum of absolute differences (SAD) between the reference and the trace at that offset is calculated and printed; a small value means a close match.\n",
    "\n",
    "**NB**: An artifact of Jupyter is that once you run this cell, the `%matplotlib notebook` magic may no longer work. If you want to get interactive graphs again, you'll need to restart the kernel & not run this cell. Restarting the kernel will require you to capture data again. You can easily do the remainder of the tutorial without the interactive graphs (meaning it's not necessary to restart the kernel after this cell)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ipywidgets as widgets\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display\n",
    "import numpy as np\n",
    "%matplotlib inline\n",
    "\n",
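    "# Limit the offset so the reference window never runs past the end of the trace\n",
    "@widgets.interact(offset=(0, len(trace)-len(ref_trace)))\n",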
    "def plotsad(offset=3600):\n",
    "    plt.plot(trace, 'r')\n",
    "    plt.plot(range(offset, offset+len(ref_trace)), ref_trace, 'b', alpha=0.6)\n",
    "    plt.figure()\n",
    "    diff = ref_trace-trace[offset:(offset+len(ref_trace))]\n",
    "    plt.plot(range(offset, offset+len(ref_trace)), diff, 'g', alpha=0.6)\n",
    "    print(np.sum(abs(diff)))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "start = 3600\n",
    "rsa_one = trace[start:(start+500)]      \n",
    "diffs = []\n",
    "for i in range(0, len(trace)-len(rsa_one)):\n",
    "    diff = trace[i:(i+len(rsa_one))] - rsa_one    \n",
    "    diffs.append(np.sum(abs(diff)))\n",
    "    \n",
    "plt.figure()\n",
    "plt.plot(diffs)\n",
    "plt.title('SAD Match for RSA')\n",
    "plt.ylabel('SAD Difference')\n",
    "plt.xlabel('Offset')"
   ]
  },
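  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the SAD sweep in isolation, here is a minimal sketch on synthetic data rather than a captured trace. The names `sad_sweep`, `pattern` and `synthetic`, the noise level, and the threshold of 2.0 are all made up for illustration; a real trace will need its own threshold.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sad_sweep(trace, ref):\n",
    "    # Sum of absolute differences between ref and every window of trace\n",
    "    n = len(trace) - len(ref) + 1\n",
    "    return np.array([np.sum(np.abs(trace[i:i + len(ref)] - ref)) for i in range(n)])\n",
    "\n",
    "# Synthetic stand-in for a power trace: one pattern repeated at known offsets\n",
    "rng = np.random.default_rng(0)\n",
    "pattern = rng.normal(size=50)\n",
    "synthetic = rng.normal(scale=0.01, size=400)\n",
    "for start in (20, 150, 330):\n",
    "    synthetic[start:start + 50] += pattern\n",
    "\n",
    "ref = synthetic[20:70]\n",
    "sad = sad_sweep(synthetic, ref)\n",
    "matches = np.where(sad < 2.0)[0]\n",
    "print(matches)\n",
    "```\n",
    "\n",
    "The SAD drops close to zero only where the reference lines up with a repetition of the pattern, which is the same effect the plot above shows for the RSA loop."
   ]
  },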
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Find a good threshold based on what seems to result in a \"SAD match\". Use that number in the following block for the np.where (here it's set to 10.0):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "times = np.where(np.array(diffs) < 10.0)[0]\n",
    "deltalist = []\n",
    "for i in range(0, len(times)-1):\n",
    "    delta = times[i+1] - times[i]\n",
    "    deltalist.append(delta)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we can then plot the time deltas:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(deltalist, range(0, len(deltalist)), 'or')\n",
    "plt.grid(True)\n",
    "plt.title('A Learned Comparison of RSA Execution Time')\n",
    "plt.ylabel('Processing Bit Number')\n",
    "plt.xlabel('Time Delta (based on SAD Match)')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting it all together - let's try and use a good difference in the processing time to figure out if a bit is a 0 or 1. In my example the above graph seems to have a nice split around 1400 cycle time delta. Yours might change with different compilers!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "key = \"\"\n",
    "times = np.where(np.array(diffs) < 10)[0]\n",
    "for i in range(0, len(times)-1):\n",
    "    delta = times[i+1] - times[i]\n",
    "    #print(delta)\n",
    "    if delta > 1400:\n",
    "        key += \"1\"\n",
    "    else:\n",
    "        key += \"0\"\n",
    "key += \"0\"\n",
    "print(\"%04X\"%int(key, 2))"
   ]
  },
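  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same delta-to-bit logic can be tried on made-up numbers. In this sketch `bits_from_match_times` and `example_times` are hypothetical, and the 1400 threshold is simply the value used above.\n",
    "\n",
    "```python\n",
    "def bits_from_match_times(times, threshold):\n",
    "    # One bit per gap between consecutive matches: a long gap is read as '1'\n",
    "    bits = ''\n",
    "    for earlier, later in zip(times[:-1], times[1:]):\n",
    "        bits += '1' if (later - earlier) > threshold else '0'\n",
    "    return bits\n",
    "\n",
    "# Made-up match offsets with gaps of 2000, 1000, 1000 and 2000 samples\n",
    "example_times = [0, 2000, 3000, 4000, 6000]\n",
    "print(bits_from_match_times(example_times, 1400))  # prints 1001\n",
    "```\n",
    "\n",
    "Five match points give only four gaps, so this approach always yields one bit fewer than the number of matches."
   ]
  },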
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hopefully that recovered the encryption key you set earlier! The last caveat is the *last bit* isn't recovered. Can you figure out a way to recover it? Why isn't it recovered?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting Bandwidth-Limited Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two additional notes that you might have heard about before:\n",
    " \n",
    " * RSA attacks often require a lot less bandwidth. \n",
    " * Can we just \"read off\" the 1 vs 0?\n",
    " \n",
    "In this example, we didn't see the obvious '1 vs 0'. But some simple filtering can help us recover it. Experiment with the bandwidth (set by the `bw` variable) of this low-pass filter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.signal import butter, lfilter\n",
    "\n",
    "bw = 0.001\n",
    "\n",
    "b, a = butter(3, bw, btype='low')\n",
    "y = lfilter(b, a, trace)\n",
    "\n",
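    "# Approximate cutoff in Hz: butter() takes bw as a fraction of the Nyquist\n",
    "# frequency, i.e. half the sample rate (assumed to be 7.37 MS/s here)\n",
    "print(bw * 7.37E6 / 2)\n",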
    "\n",
    "plt.plot(y)"
   ]
  },
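  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One point that may help when experimenting: when `butter` is called without an `fs` argument, its cutoff is given as a fraction of the Nyquist frequency (half the sample rate). The sketch below converts such a value to Hz. `cutoff_hz` is a hypothetical helper, and the 7.37 MS/s sample rate is an assumption based on the clock settings used in this notebook.\n",
    "\n",
    "```python\n",
    "def cutoff_hz(normalized_bw, sample_rate_hz):\n",
    "    # Convert a Nyquist-normalized cutoff (0 to 1) into a frequency in Hz\n",
    "    if not 0.0 < normalized_bw < 1.0:\n",
    "        raise ValueError('normalized_bw must be between 0 and 1')\n",
    "    return normalized_bw * sample_rate_hz / 2.0\n",
    "\n",
    "print(cutoff_hz(0.001, 7.37e6))\n",
    "```\n",
    "\n",
    "Under that assumption, `bw = 0.001` corresponds to a cutoff of only a few kHz, far below the target's clock frequency."
   ]
  },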
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Can you adjust the red & black lines such that peaks for '1' are above the black line, and peaks for '0' fall between the red and black lines? Once you have good values, we can build *another* way of recovering the RSA private key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ipywidgets as widgets\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display\n",
    "%matplotlib inline\n",
    "\n",
    "#You may need to adjust the ranges!!!\n",
    "@widgets.interact(low=(0.0, 0.15, 0.002),\n",
    "                  high=(0.0,0.15, 0.002))\n",
    "def plotlimits(low=0.085, high=0.095):\n",
    "    plt.plot(y, 'b')\n",
    "    plt.plot([0, len(y)], [low, low], 'r')\n",
    "    plt.plot([0, len(y)], [high, high], 'k')\n",
    "    \n",
    "    print(\"Low: {:f}, High: {:f}\".format(low, high))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In fact, we can use some simple Python magic to avoid needing the dual threshold. The real goal is just to find the local maxima and check whether each one is above or below a single threshold. The good news is that later versions of scipy have this built in for us! So we can check whether the local maxima function gives sensible results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.signal import argrelextrema\n",
    "\n",
    "# for local maxima - see this\n",
    "local_max = argrelextrema(y, np.greater)\n",
    "print(local_max)\n",
    "#Note we really need to index to get the actual list\n",
    "print(local_max[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And with a bit of finagling, we can convert that into a key. You'll need to modify the threshold that is hard-coded in this example based on what you found made sense earlier:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key_filt = \"\"\n",
    "for m in local_max[0]:\n",
    "    if y[m] > 0.11:\n",
    "        key_filt += \"1\"\n",
    "    else:\n",
    "        key_filt += \"0\"\n",
    "key_filt += \"0\"\n",
    "print(\"%04X\"%int(key_filt, 2))"
   ]
  },
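  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The peak-and-threshold idea can also be checked on a toy signal. The sketch below is a numpy-only stand-in for `argrelextrema`, not the method used above; `classify_peaks`, the `smoothed` values, and the 0.11 threshold are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def classify_peaks(signal, threshold):\n",
    "    # Find interior samples strictly greater than both neighbours,\n",
    "    # then read each peak as '1' if it exceeds the threshold, else '0'\n",
    "    signal = np.asarray(signal, dtype=float)\n",
    "    interior = np.arange(1, len(signal) - 1)\n",
    "    is_peak = (signal[interior] > signal[interior - 1]) & (signal[interior] > signal[interior + 1])\n",
    "    peaks = interior[is_peak]\n",
    "    return ''.join('1' if signal[p] > threshold else '0' for p in peaks)\n",
    "\n",
    "# Made-up smoothed trace with peaks of height 0.12, 0.09, 0.12, 0.09\n",
    "smoothed = [0.05, 0.12, 0.05, 0.09, 0.05, 0.12, 0.05, 0.09, 0.05]\n",
    "print(classify_peaks(smoothed, 0.11))  # prints 1010\n",
    "```\n",
    "\n",
    "A single threshold is enough here because every loop iteration produces exactly one peak; only the peak height differs between a 1 and a 0."
   ]
  },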
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yipee! Hopefully this worked out and recovered the same key as above. How low of a bandwidth can you work with and still recover the key?\n",
    "\n",
    "Remember that the original power trace had perfect synchronization. In real life you might need a higher bandwidth, but it should still be a lot lower than what a DPA attack needs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial has demonstrated the use of the power side-channel to attack an 8-bit RSA implementation. We attacked it in two ways: by using a SAD match to find the interesting points, and by applying a low-pass filter to make it more obvious whether each section is a 1 or a 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scope.dis()\n",
    "target.dis()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert int(key_filt, 2) == 0x8AB0, \"Failed to break key with filter, adjust maximum\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert int(key, 2) == 0x8AB0, \"Failed to break key with SAD Match\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
